{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04185.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04185.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tajOfFIr94mW",
        "colab_type": "text"
      },
      "source": [
        "## 5.5 卷积神经网络 ( LeNet )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kx4Qj6nN-UZn",
        "colab_type": "text"
      },
      "source": [
        "### 5.5.1 LeNet 模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Bd_YeneP-ZLl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch \n",
        "from torch import nn, optim \n",
        "import d2l \n",
        "import time \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "\n",
        "class LeNet(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(LeNet, self).__init__()\n",
        "        self.conv = nn.Sequential(\n",
        "            nn.Conv2d(1, 6, 5), \n",
        "            nn.Sigmoid(), \n",
        "            nn.MaxPool2d(2, 2), \n",
        "            nn.Conv2d(6, 16, 5), \n",
        "            nn.Sigmoid(), \n",
        "            nn.MaxPool2d(2, 2), \n",
        "        )\n",
        "\n",
        "        self.fc = nn.Sequential(\n",
        "            nn.Linear(16*4*4, 120), \n",
        "            nn.Sigmoid(), \n",
        "            nn.Linear(120, 84), \n",
        "            nn.Sigmoid(), \n",
        "            nn.Linear(84, 10), \n",
        "        )\n",
        "\n",
        "    def forward(self, img):\n",
        "        feature = self.conv(img)\n",
        "        output = self.fc(feature.view(img.shape[0], -1))\n",
        "        return output"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EF1bBlntUfAv",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 302
        },
        "outputId": "a15c1561-ba42-4437-e466-53f3bec16bb5"
      },
      "source": [
        "net = LeNet()\n",
        "net"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "LeNet(\n",
              "  (conv): Sequential(\n",
              "    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))\n",
              "    (1): Sigmoid()\n",
              "    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
              "    (4): Sigmoid()\n",
              "    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
              "  )\n",
              "  (fc): Sequential(\n",
              "    (0): Linear(in_features=256, out_features=120, bias=True)\n",
              "    (1): Sigmoid()\n",
              "    (2): Linear(in_features=120, out_features=84, bias=True)\n",
              "    (3): Sigmoid()\n",
              "    (4): Linear(in_features=84, out_features=10, bias=True)\n",
              "  )\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CeABh6h0UnBR",
        "colab_type": "text"
      },
      "source": [
        "### 5.5.2 获取数据和训练模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hRaBDF6EUwEv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_size = 256 \n",
        "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BROEVPCuU5xh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def evaluate_accuracy(data_iter, net, device=None):\n",
        "    if device is None and isinstance(net, torch.nn.Module):\n",
        "        device = list(net.parameters())[0].device \n",
        "    acc_sum, n = 0.0, 0 \n",
        "    with torch.no_grad():\n",
        "        for X, y in data_iter:\n",
        "            if isinstance(net, torch.nn.Module):\n",
        "                net.eval()\n",
        "                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()\n",
        "                net.train()\n",
        "            else:\n",
        "                if('is_training' in net.__code__.co_varnames):\n",
        "                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()\n",
        "                else:\n",
        "                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()\n",
        "            n += y.shape[0]\n",
        "    return acc_sum / n "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "T01cCqV2WotO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):\n",
        "    net = net.to(device)\n",
        "    print('training on', device)\n",
        "    loss = torch.nn.CrossEntropyLoss()\n",
        "    for epoch in range(num_epochs):\n",
        "        train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()\n",
        "        for X, y in train_iter:\n",
        "            X = X.to(device)\n",
        "            y = y.to(device)\n",
        "            y_hat = net(X)\n",
        "            l = loss(y_hat, y)\n",
        "            optimizer.zero_grad()\n",
        "            l.backward()\n",
        "            optimizer.step()\n",
        "            train_l_sum += l.cpu().item()\n",
        "            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()\n",
        "            n += y.shape[0]\n",
        "            batch_count += 1 \n",
        "        test_acc = evaluate_accuracy(test_iter, net)\n",
        "        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' \n",
        "          % (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0XhbfHk8ZDim",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 118
        },
        "outputId": "93fd3465-4685-46d5-ed60-8c55df896899"
      },
      "source": [
        "lr, num_epochs = 0.001, 5 \n",
        "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
        "train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "training on cuda\n",
            "epoch 1, loss 1.9168, train acc 0.287, test acc 0.591, time 8.1 sec\n",
            "epoch 2, loss 0.9513, train acc 0.642, test acc 0.691, time 8.1 sec\n",
            "epoch 3, loss 0.7637, train acc 0.719, test acc 0.728, time 8.1 sec\n",
            "epoch 4, loss 0.6764, train acc 0.745, test acc 0.751, time 8.1 sec\n",
            "epoch 5, loss 0.6208, train acc 0.762, test acc 0.750, time 8.1 sec\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}